# Change the path to the location of Earwigs.csv on your computer
EW <- read_csv("../Data/Earwigs.csv", show_col_types = FALSE)
ggplot() +
geom_point(
data = EW,
aes(Density, Proportion_forceps),
color = "steelblue",
size = 3
)
Problem Set 4
In this problem set, before we get to the models, we want to give some demonstration code for two packages that help with some aspects of Bayesian inference. We’ll walk through the code so you have templates to use for later analyses in this problem set.
priorsense Package
Usually, those new to Bayesian analysis have many questions about how to set a prior and how to evaluate whether the chosen prior is appropriate. We will show you a new package that helps to assess this second question - is my prior appropriate? This package is the priorsense package. If you are interested in a less technical introduction, here is a short video describing the package. We decided to leave this out of the lectures in part because it would have taken time away from other topics and in part because it is an area of still-active research.
The approach that Kallioinen et al. take is to use importance sampling (as in PSIS-LOO-CV) of the prior or likelihood raised to an exponent (the “power” in power-scaling) to detect instances of prior-data conflict, wherein the prior contains too much information (e.g., is too constrained), the likelihood (data) has too little information, or some combination of the two.
The approach is simply to give either the prior or the likelihood more (or less) power by raising it to an exponent \(\alpha\) that varies around 1 (i.e., no scaling). For example, in testing the prior \(Pr(\theta)\) sensitivity:
\[Pr(\theta|y) \propto Pr(y|\theta) \, Pr(\theta)^\alpha\]
the prior is raised to an exponent \(\alpha\) that can vary. The response of the posterior \(Pr(\theta|y)\) to changing the strength of the prior tells us how sensitive the model is to the prior. Note here that we are only looking at the numerator of Bayes’ Rule, because MCMC methods make dealing with the probability of the data \(Pr(y)\) unnecessary.
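The likelihood is tested analogously, by raising it (rather than the prior) to the exponent. Side by side, the two perturbations are:

\[\begin{align} Pr(\theta|y) & \propto Pr(y|\theta) \, Pr(\theta)^\alpha && \textrm{(prior power-scaling)} \\ Pr(\theta|y) & \propto Pr(y|\theta)^\alpha \, Pr(\theta) && \textrm{(likelihood power-scaling)} \end{align}\]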
The prior-scaling approach is complementary to the prior predictive simulation that we have been using so far. The general approach would be to develop priors via prior predictive simulation (using the different options in ulam() or brm() to sample from the prior only) and then check that those priors are adequate using the functions in priorsense.
The main functions of this package are:

- powerscale_sequence() evaluates the prior/likelihood sensitivity across a range of powers. This function can be wrapped in either powerscale_plot_dens() or powerscale_plot_quantities() to plot changes in the posterior densities or the dependency of the posterior on prior or likelihood scaling, respectively.
- powerscale_sensitivity() is the main function to test the sensitivity of the prior and likelihood via power-scaling.
Install priorsense from CRAN.
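Once installed, a minimal workflow template looks something like the sketch below (this assumes a brms model has already been fit and stored in an object named fit; the variable name b_Intercept is just an illustrative placeholder):

```r
library(priorsense)

# Overall sensitivity table: one row per parameter plus a diagnosis column
powerscale_sensitivity(fit)

# Pre-compute the sequence of power-scaled posteriors once...
pss <- powerscale_sequence(fit)

# ...then visualize it: overlaid densities, or quantities vs. alpha
powerscale_plot_dens(pss, variable = "b_Intercept")
powerscale_plot_quantities(pss, quantity = c("mean", "sd"))
```

We will use this same pattern repeatedly below.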
We will use the Earwigs data from Problem Set 3 to explore how to use this package.
Load the data and plot:
Check the variables that can have priors in the brm() model:
get_prior(Proportion_forceps ~ 1 + Density, data = EW)
                  prior     class    coef group resp dpar nlpar lb ub tag       source
                 (flat)         b                                               default
                 (flat)         b Density                                  (vectorized)
 student_t(3, 0.3, 2.5) Intercept                                               default
   student_t(3, 0, 2.5)     sigma                                 0             default
We will skip the prior predictive simulation in this example. To use the priorsense functions, you have to fit the model with the data (i.e., not sampling from the prior only). So in practice you would do the prior predictive simulation here, sampling from the prior to get a prospective set of priors. Then you would use the functions in priorsense to evaluate the priors with the model.
For now, we will set the priors to be very bad (which we know from doing Problem Set 3), to see what the diagnostics look like:
fm <- brm(
Proportion_forceps ~ 1 + Density,
data = EW,
prior = c(
prior(normal(0, 0.0001), class = b),
prior(normal(0, 0.0001), class = Intercept),
prior(normal(0, 0.01), class = sigma)
),
refresh = 0
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.5 seconds.
Loading required namespace: rstan
Power-scale sensitivity visual diagnostics
powerscale_plot_dens() plots overlapping density plots color-coded by the \(\alpha\) power-scaling exponents in the tested range. When the lines overlap, it indicates that the density estimate is not sensitive to power-scaling. Considering the prior, for example, if the lines differ, then the prior density changes when scaled by a power (bad). So what we want to see is that the density lines for the prior are superimposed. For the likelihood, the lines should not overlap, indicating that the data are informative enough to move the posterior.
Because the scales of the variables are so different, we will plot them separately. Notice the embedded powerscale_sequence(fm). We could just as well pre-compute this and pass it to the function (with more data, that would be advantageous for speed).
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "b_Density"
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "b_Intercept"
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "sigma"
)
You can see how all of the priors are sensitive to scaling, particularly sigma. We also get messages about high Pareto \(k\) values, indicating poor fit.
powerscale_plot_quantities() visualizes the rate of change in the posterior as \(\alpha\) changes. Ideally we would like to see a flat-ish line for the prior, indicating that the prior is not sensitive to scaling. There are many options for the divergence measure, but the default “Cumulative Jensen-Shannon distance” (cjs_dist) seems to work fine.
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Density", "b_Intercept", "sigma")
)
Power-scale sensitivity table
Finally, the function powerscale_sensitivity() makes a table of sensitivity values. Values > 0.05 indicate sensitivity of the prior or likelihood. The last column provides a diagnosis.
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_Intercept 0.086 0.023 potential strong prior / weak likelihood
b_Density 0.085 0.022 potential strong prior / weak likelihood
sigma 1.064 1.000 potential prior-data conflict
Intercept 0.088 0.025 potential strong prior / weak likelihood
Here we have a “weak likelihood” for the first two rows, because the priors on the means and Intercept are way too strong (Normal(0, 0.0001)). This means that the data are insufficient to move the likelihood away from the prior.
The prior for sigma is also poor, resulting in a prior-data conflict, where one goes up and one goes down as \(\alpha\) changes.
Improved priors
Let’s use the priors that we developed for problem set 3 and hopefully see a better pattern.
fm <- brm(
Proportion_forceps ~ 1 + Density,
data = EW,
prior = c(
prior(normal(0, 0.1), class = b),
prior(normal(0, 1), class = Intercept),
prior(normal(0, 1), class = sigma)
),
refresh = 0,
iter = 5e3
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.5 seconds.
Evaluating the priors:
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "b_Density"
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "b_Intercept"
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "sigma"
)
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Density", "b_Intercept", "sigma")
)
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_Intercept 0.001 0.101 -
b_Density 0.000 0.101 -
sigma 0.002 0.238 -
Intercept 0.002 0.098 -
Notice how in the density plots, the densities of the priors are all overlapping (lack of sensitivity) and the posteriors are not overlapping (the data are able to inform the posterior). The quantity plot shows relatively flat lines for the priors and likelihoods that are sensitive to scaling. Finally, the table has all values < 0.05 for the prior.
tidybayes Package
We want to add one more set of analysis tools to our general Bayesian inference kit: the tidybayes package. tidybayes has a variety of functions for extracting parts of fit models (from lots of model fitting interfaces), augmenting model fits with various kinds of predicted values, and making some very impressive visualizations.
The documentation has a [page of visualizations from brms models](http://mjskay.github.io/tidybayes/articles/tidy-brms.html).
tidybayes is particularly useful for working with the posteriors of multilevel models, which is what the demo code that the documentation provides is based on. Our usage here will be a little more pedestrian, but we can still see how useful the package can be.
We will adapt some of the demo code to plot the posterior for the earwigs model we just fit.
When you are trying to figure out what the variable names are in a model, the function get_variables() returns them:
library(tidybayes)
get_variables(fm)
 [1] "b_Intercept"   "b_Density"     "sigma"         "Intercept"
[5] "lprior" "lp__" "accept_stat__" "treedepth__"
[9] "stepsize__" "divergent__" "n_leapfrog__" "energy__"
By default, brm() returns parameters prefixed with b_ for “b” parameters (what are often called main or fixed effects) and r_ for random/multilevel effects (though we aren’t doing multilevel models in this module, it’s useful to know).
tidy_draws() is the simplest way to extract a posterior. You can see how it returns a lot of diagnostics as well: acceptance statistic, tree depth, step size, and whether that draw was a divergence or not.
tidybayes has a summary()-like function summarise_draws() (Commonwealth spelling only). We can pipe the output of tidy_draws() directly to it.
fm |> tidy_draws()
# A tibble: 10,000 × 15
.chain .iteration .draw b_Intercept b_Density sigma Intercept lprior lp__
<int> <int> <int> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 1 1 1 0.109 0.00309 0.188 0.188 0.203 2.97
2 1 2 2 0.151 0.00297 0.188 0.227 0.195 5.17
3 1 3 3 0.155 0.00332 0.172 0.240 0.195 6.00
4 1 4 4 0.182 0.00603 0.183 0.337 0.164 7.23
5 1 5 5 0.193 0.00505 0.170 0.322 0.171 7.90
6 1 6 6 0.175 0.00459 0.151 0.292 0.184 7.99
7 1 7 7 0.210 0.00256 0.155 0.275 0.189 6.51
8 1 8 8 0.136 0.00606 0.138 0.291 0.185 7.38
9 1 9 9 0.169 0.00590 0.184 0.320 0.169 7.53
10 1 10 10 0.182 0.00310 0.189 0.261 0.187 6.48
# ℹ 9,990 more rows
# ℹ 6 more variables: accept_stat__ <dbl>, treedepth__ <dbl>, stepsize__ <dbl>,
# divergent__ <dbl>, n_leapfrog__ <dbl>, energy__ <dbl>
fm |> tidy_draws() |> summarise_draws()
# A tibble: 12 × 10
variable mean median sd mad q5 q95 rhat ess_bulk
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 b_Interc… 0.172 0.173 0.0587 0.0567 0.0746 0.266 1.00 7607.
2 b_Density 0.00511 0.00511 0.00177 0.00170 0.00224 0.00802 1.00 10502.
3 sigma 0.173 0.169 0.0299 0.0275 0.132 0.227 1.00 6204.
4 Intercept 0.303 0.303 0.0373 0.0357 0.241 0.364 1.00 5536.
5 lprior 0.176 0.177 0.0130 0.0117 0.153 0.194 1.00 6149.
6 lp__ 6.54 6.88 1.31 1.06 3.94 7.95 1.00 4393.
7 accept_s… 0.916 0.955 0.109 0.0634 0.695 1 1.00 11953.
8 treedept… 2.32 2 0.556 0 2 3 1.01 6313.
9 stepsize… 0.614 0.616 0.0372 0.0394 0.561 0.665 Inf 4.01
10 divergen… 0 0 0 0 0 0 NA NA
11 n_leapfr… 5.17 7 2.08 0 3 7 1.01 7464.
12 energy__ -5.05 -5.38 1.78 1.66 -7.31 -1.72 1.00 3908.
# ℹ 1 more variable: ess_tail <dbl>
Many of the tidybayes functions require the posterior to be in a slightly different format, that of an rvar. An rvar is a compact way to store a distribution of values. We can use spread_rvars() to extract only a few of the columns. We then pipe that output to median_hdi() to get the median and 89% HDI of the posterior for each.
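As a quick illustration of what an rvar is (this uses the posterior package, which tidybayes builds on; the draws here are simulated, not from our model):

```r
library(posterior)

# An rvar stores a whole vector of draws as a single "value" and
# prints as a mean ± sd summary
x <- rvar(rnorm(4000, mean = 10, sd = 2))
x

# Arithmetic on rvars operates draw-wise, which is what makes
# posterior contrasts (subtracting one group from another) so convenient
y <- rvar(rnorm(4000, mean = 8, sd = 2))
x - y
```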
fm |>
spread_rvars(b_Intercept, b_Density, sigma)
# A tibble: 1 × 3
b_Intercept b_Density sigma
<rvar[1d]> <rvar[1d]> <rvar[1d]>
1 0.17 ± 0.059 0.0051 ± 0.0018 0.17 ± 0.03
fm |>
spread_rvars(b_Intercept, b_Density, sigma) |>
median_hdi(.width = 0.89)
# A tibble: 1 × 12
b_Intercept b_Intercept.lower b_Intercept.upper b_Density b_Density.lower
<dbl> <dbl> <dbl> <dbl> <dbl>
1 0.173 0.0768 0.263 0.00511 0.00236
# ℹ 7 more variables: b_Density.upper <dbl>, sigma <dbl>, sigma.lower <dbl>,
# sigma.upper <dbl>, .width <dbl>, .point <chr>, .interval <chr>
There are many options in tidybayes to plot distributions and intervals. Here is a point + interval plot of the three main parameters.
fm |>
spread_rvars(b_Intercept, b_Density, sigma) |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
stat_pointinterval(.width = c(0.89, 0.97))
b_Density is very small relative to the other parameters, so its variation looks really small in comparison. We might just plot it separately:
fm |>
spread_rvars(b_Density) |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
stat_pointinterval(.width = c(0.89, 0.97))
From this plot you can see how the posterior is credibly different from 0, even though the parameter estimate is small.
tidybayes works well with data_grid() from the modelr package. Like crossing(), which we have used before, data_grid() generates the pairwise combinations of the variables that are passed to it, but without needing as many details (by default it uses the range of continuous variables and all the levels of factors).
If we then pipe that out to add_epred_draws() called with the fitted model, we can create a tibble with the values of Density across a range paired with the expected value of Proportion_forceps.
The second block of code pipes these values to ggplot() to make a plot of the observed data along with ribbons representing the 50%, 89%, and 97% HDIs for the expected values. Remember that these values do not include the standard deviation, so they are relatively narrow.
library(modelr)
# Expected parameter estimates
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
add_epred_draws(fm)
# A tibble: 2,000,000 × 6
# Groups: Density, .row [200]
Density .row .chain .iteration .draw .epred
<dbl> <int> <int> <int> <int> <dbl>
1 0.152 1 NA NA 1 0.110
2 0.152 1 NA NA 2 0.151
3 0.152 1 NA NA 3 0.155
4 0.152 1 NA NA 4 0.183
5 0.152 1 NA NA 5 0.193
6 0.152 1 NA NA 6 0.176
7 0.152 1 NA NA 7 0.210
8 0.152 1 NA NA 8 0.137
9 0.152 1 NA NA 9 0.170
10 0.152 1 NA NA 10 0.182
# ℹ 1,999,990 more rows
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
add_epred_draws(fm) |>
ggplot(aes(x = Density, y = Proportion_forceps)) +
stat_lineribbon(aes(y = .epred), .width = c(0.5, 0.89, 0.97), alpha = 0.5) +
geom_point(data = EW)
We can do the same but generate a posterior predictive distribution plot by calling add_predicted_draws() instead (note that the variable is .prediction rather than .epred).
# Posterior predictive distribution
EW |>
data_grid(Density = seq_range(Density, n = 200)) |>
add_predicted_draws(fm) |>
ggplot(aes(x = Density, y = Proportion_forceps)) +
stat_lineribbon(
aes(y = .prediction),
.width = c(0.5, 0.89, 0.97),
alpha = 0.5
) +
geom_point(data = EW)
Almost all of the points fall within the 97% interval, just like we would predict. Observe that the lines and edges are pretty rough. We could sample more iterations to smooth those out.
In the analyses below, try to add the packages above to your now pretty well-developed Bayesian modeling routines. Also see if you can work with the mcmc_ functions from bayesplot and pp_check() for plotting prior/posterior predictive checks.
These are three models that you saw in Quantitative Methods 1. We will leave much of the details of the analysis to you, providing some guidance for three challenging kinds of models to fit and interpret.
ANOVA-like
The file Heart_Transplants.csv contains data on Survival time (in days) for heart transplant patients with varying degrees of Mismatch between the donor and recipient. You will need to convert Mismatch to a factor and get the levels in the correct order: low, medium, high. Low indicates a relatively good match and high a poor match. The data have a pronounced right skew.
Load the data, visualize, and transform how you see fit.
# FIXME
HT <- read_csv("../Data/Heart_Transplants.csv", show_col_types = FALSE) |>
mutate(Mismatch = fct_inorder(Mismatch))
ggplot(HT, aes(x = Mismatch, y = Survival)) +
geom_point(position = position_jitter(width = 0.1, seed = 4564356)) +
stat_summary(fun = mean, geom = "point", size = 3, color = "red") +
stat_summary(
fun.data = mean_se,
geom = "errorbar",
width = 0.1,
linewidth = 1,
color = "red"
)
HT |>
group_by(Mismatch) |>
summarize(mean_Survival = mean(Survival), sd_Survival = sd(Survival))
# A tibble: 3 × 3
Mismatch mean_Survival sd_Survival
<fct> <dbl> <dbl>
1 Low 311. 431.
2 Medium 269 339.
3 High 71.9 86.0
HT <- HT |>
mutate(logSurvival = log(Survival))
ggplot(HT, aes(x = Mismatch, y = logSurvival)) +
geom_point(position = position_jitter(width = 0.1, seed = 4564356)) +
stat_summary(fun = mean, geom = "point", size = 3, color = "red") +
stat_summary(
fun.data = mean_se,
geom = "errorbar",
width = 0.1,
linewidth = 1,
color = "red"
)
Model specification
\[\begin{align} \mathrm{logSurvival} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Mismatch}] \\ \end{align}\]
Prior specification and prior predictive check
# FIXME
get_prior(logSurvival ~ Mismatch - 1, data = HT)
                prior class           coef group resp dpar nlpar lb ub tag       source
               (flat)     b                                                     default
               (flat)     b   MismatchHigh                                 (vectorized)
               (flat)     b    MismatchLow                                 (vectorized)
               (flat)     b MismatchMedium                                 (vectorized)
 student_t(3, 0, 2.5) sigma                                  0                  default
PP <- brm(
logSurvival ~ Mismatch - 1,
data = HT,
prior = c(prior(normal(4, 3), class = b)),
refresh = 0,
sample_prior = "only"
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
pp_check(
PP,
type = "stat_grouped",
group = "Mismatch",
stat = "mean",
ndraws = 500,
binwidth = 0.5
)
Note: in most cases the default test statistic 'mean' is too weak to detect anything of interest.
Final model specification
\[\begin{align} \mathrm{logSurvival} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Mismatch}] \\ b[\mathrm{Mismatch}] & \sim Normal(4, 3) \\ \sigma & \sim HalfNormal(0, 3) \end{align}\]
Sampling
# FIXME
fm <- brm(
logSurvival ~ Mismatch - 1,
data = HT,
prior = c(prior(normal(4, 3), class = b), prior(normal(0, 3), class = sigma)),
refresh = 0,
iter = 5e3
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)
        prior class           coef group resp dpar nlpar lb ub tag       source
 normal(4, 3)     b                                                        user
 normal(4, 3)     b   MismatchHigh                                 (vectorized)
 normal(4, 3)     b    MismatchLow                                 (vectorized)
 normal(4, 3)     b MismatchMedium                                 (vectorized)
 normal(0, 3) sigma                                  0                     user
Diagnostics
# FIXME
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_MismatchLow", "b_MismatchMedium", "b_MismatchHigh")
)
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_MismatchLow", "b_MismatchMedium", "b_MismatchHigh")
)
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_MismatchLow 0.004 0.098 -
b_MismatchMedium 0.007 0.099 -
b_MismatchHigh 0.004 0.081 -
sigma 0.009 0.184 -
# FIXME
summary(fm)
 Family: gaussian
Links: mu = identity
Formula: logSurvival ~ Mismatch - 1
Data: HT (Number of observations: 39)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
MismatchLow 4.49 0.43 3.65 5.33 1.00 10706 7480
MismatchMedium 4.78 0.44 3.92 5.64 1.00 10358 7624
MismatchHigh 3.73 0.46 2.85 4.65 1.00 10476 6640
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 1.61 0.20 1.27 2.04 1.00 8951 6829
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# FIXME
mcmc_trace(fm)
mcmc_rank_overlay(fm)
Posterior predictive simulation
# FIXME
pp_check(
fm,
type = "stat_grouped",
group = "Mismatch",
stat = "mean",
ndraws = 500,
binwidth = 0.1
)
Note: in most cases the default test statistic 'mean' is too weak to detect anything of interest.
Summarizing the posterior
# FIXME
# Median HDI
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
median_hdi(.width = 0.89)
# A tibble: 1 × 12
Low Low.lower Low.upper Medium Medium.lower Medium.upper High High.lower
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4.49 3.76 5.13 4.78 4.07 5.47 3.73 3.00
# ℹ 4 more variables: High.upper <dbl>, .width <dbl>, .point <chr>,
# .interval <chr>
# Lots of different options for visualizing
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
pivot_longer(cols = everything()) |>
mutate(name = fct_inorder(name)) |>
ggplot(aes(y = name, dist = value)) +
stat_pointinterval(.width = c(0.89, 0.97))
HT |>
data_grid(Mismatch) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Mismatch)) +
stat_slab()
HT |>
data_grid(Mismatch) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Mismatch)) +
stat_interval(.width = c(0.50, 0.89, 0.97)) +
geom_point(aes(x = logSurvival), data = HT) +
scale_color_brewer()
# Kruschke plot
library(distributional)
HT |>
data_grid(Mismatch) |>
add_epred_draws(fm, dpar = c("mu", "sigma")) |>
sample_draws(30) |>
ggplot(aes(y = Mismatch)) +
stat_slab(
aes(xdist = dist_normal(mu = mu, sigma = sigma)),
slab_color = "gray65",
alpha = 0.1,
fill = NA
) +
geom_point(
aes(x = logSurvival),
data = HT,
shape = 21,
fill = "#9ECAE1",
size = 3
)
Test the hypothesis that Medium and High mismatch differ from Low using contrasts.
fm |>
spread_rvars(b_MismatchLow, b_MismatchMedium, b_MismatchHigh) |>
set_names(distinct(HT, Mismatch) |> pull()) |>
mutate(Med_v_Low = Medium - Low, High_v_Low = High - Low, .keep = "none") |>
pivot_longer(cols = everything()) |>
ggplot(aes(y = name, dist = value)) +
stat_pointinterval(.width = c(0.5, 0.89))
2x2 factorial design
The file Bird_Plasma.xlsx contains factorial data on blood plasma calcium concentration (Calcium, in mg Ca per 100 mL plasma) in male and female birds (Sex) each of which was treated or not with a hormone (Treatment).
- Load the data, and convert hormone and sex to factors.
- The levels of Treatment are “Hormone” and “None”. Relevel Treatment so that “None” is the base level.
- Plot a reaction norm of Calcium vs. Sex, with color encoding Treatment to get a sense for the pattern.
# FIXME
BP <- readxl::read_excel("../Data/Bird_Plasma.xlsx") |>
mutate(
Treatment = factor(Treatment),
Sex = factor(Sex),
Treatment = fct_relevel(Treatment, "None")
)
BP |> count(Treatment, Sex)
# A tibble: 4 × 3
Treatment Sex n
<fct> <fct> <int>
1 None Female 5
2 None Male 5
3 Hormone Female 5
4 Hormone Male 5
ggplot(BP, aes(x = Sex, y = Calcium, color = Treatment, group = Treatment)) +
geom_point(
position = position_jitter(width = 0.05, seed = 474577),
size = 3
) +
stat_summary(fun = mean, geom = "point", pch = 5, size = 5) +
stat_summary(fun = mean, geom = "line") +
scale_color_manual(values = c("gray50", "darkgreen"))
Model specification
We have a factorial model, so we would like to model the two main effects, Sex and Treatment, as well as the Sex by Treatment interaction term. Interactions between categorical variables are complicated to code in Bayesian models. Although you can specify the model just as you would with lm() (Sex * Treatment), setting the priors can be tricky, as can sorting out the posteriors.
One approach that works well in some (most? all?) situations is to create a new composite variable that combines the two other variables. Thus the four factorial groups become a single factor with four levels (Female-Hormone, Female-None, Male-Hormone, and Male-None). Because we are testing hypotheses using contrasts (subtracting posterior distributions), we don’t have to worry about the usual P-value-based hypothesis tests of main effects and interactions.
You can do this with a simple mutate, joining the two variables:
# FIXME
BP <- BP |>
mutate(Sex_Trt = paste(Sex, Treatment, sep = "_"))
One additional advantage of this approach is that you only need to specify a single prior for all the groups.
\[\begin{align} \mathrm{Calcium} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Sex\_Trt}] \\ \end{align}\]
Prior specification and prior predictive check
There are only 5 points per group, so the prior is potentially very powerful relative to the likelihood.
# FIXME
get_prior(Calcium ~ Sex_Trt - 1, data = BP)
                prior class                  coef group resp dpar nlpar lb ub tag       source
               (flat)     b                                                            default
               (flat)     b Sex_TrtFemale_Hormone                                 (vectorized)
               (flat)     b    Sex_TrtFemale_None                                 (vectorized)
               (flat)     b   Sex_TrtMale_Hormone                                 (vectorized)
               (flat)     b      Sex_TrtMale_None                                 (vectorized)
 student_t(3, 0, 12.1) sigma                                      0                    default
PP <- brm(
Calcium ~ Sex_Trt - 1,
data = BP,
prior = c(
prior(normal(20, 15), class = b),
prior(normal(0, 10), class = sigma)
),
refresh = 0,
sample_prior = "only"
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
Warning: 2 of 4000 (0.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
# Check the minimum predicted vs. observed
pp_check(
PP,
type = "stat_grouped",
group = "Sex_Trt",
stat = "min",
ndraws = 500,
binwidth = 2
)
# Check the mean predicted vs. observed
pp_check(
PP,
type = "stat_grouped",
group = "Sex_Trt",
stat = "mean",
ndraws = 500,
binwidth = 2
)
Note: in most cases the default test statistic 'mean' is too weak to detect anything of interest.
# Check the maximum predicted vs. observed
pp_check(
PP,
type = "stat_grouped",
group = "Sex_Trt",
stat = "max",
ndraws = 500,
binwidth = 2
)
Final model specification
\[\begin{align} \mathrm{Calcium} & \sim Normal(\mu, \sigma) \\ \mu & = b[\mathrm{Sex\_Trt}] \\ b[\mathrm{Sex\_Trt}] & \sim Normal(20, 15) \\ \sigma & \sim HalfNormal(0, 10) \end{align}\]
Sampling
# FIXME
fm <- brm(
Calcium ~ Sex_Trt - 1,
data = BP,
prior = c(
prior(normal(20, 15), class = b),
prior(normal(0, 10), class = sigma)
),
refresh = 0,
iter = 5e3
)
Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)
          prior class                  coef group resp dpar nlpar lb ub tag       source
 normal(20, 15)     b                                                              user
 normal(20, 15)     b Sex_TrtFemale_Hormone                                   (vectorized)
 normal(20, 15)     b    Sex_TrtFemale_None                                   (vectorized)
 normal(20, 15)     b   Sex_TrtMale_Hormone                                   (vectorized)
 normal(20, 15)     b      Sex_TrtMale_None                                   (vectorized)
  normal(0, 10) sigma                                       0                      user
Diagnostics
# FIXME
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Sex_TrtFemale_Hormone", "b_Sex_TrtFemale_None")
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Sex_TrtMale_Hormone", "b_Sex_TrtMale_None")
)
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("sigma")
)
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Sex_TrtFemale_Hormone", "b_Sex_TrtFemale_None")
)
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Sex_TrtMale_Hormone", "b_Sex_TrtMale_None")
)
powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("sigma")
)
powerscale_sensitivity(fm)
Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_Sex_TrtFemale_Hormone 0.016 0.115 -
b_Sex_TrtFemale_None 0.008 0.105 -
b_Sex_TrtMale_Hormone 0.012 0.106 -
b_Sex_TrtMale_None 0.010 0.117 -
sigma 0.013 0.320 -
# FIXME
summary(fm)
 Family: gaussian
Links: mu = identity
Formula: Calcium ~ Sex_Trt - 1
Data: BP (Number of observations: 20)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
Sex_TrtFemale_Hormone 32.29 2.10 28.03 36.46 1.00 11413
Sex_TrtFemale_None 14.99 2.11 10.82 19.12 1.00 11449
Sex_TrtMale_Hormone 27.61 2.17 23.24 31.84 1.00 11358
Sex_TrtMale_None 12.25 2.11 8.11 16.39 1.00 11459
Tail_ESS
Sex_TrtFemale_Hormone 6567
Sex_TrtFemale_None 7407
Sex_TrtMale_Hormone 6075
Sex_TrtMale_None 7475
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 4.69 0.91 3.30 6.85 1.00 7725 7411
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
BP |>
data_grid(Sex_Trt) |>
add_predicted_draws(fm) |>
ggplot(aes(x = .prediction, y = Sex_Trt)) +
stat_slab(alpha = 0.5, fill = "firebrick4")
BP |>
data_grid(Sex_Trt) |>
add_predicted_draws(fm) |>
median_hdi(.width = 0.89)
Posterior predictive simulation
# FIXME
# Check the minimum predicted vs. observed
pp_check(
fm,
type = "stat_grouped",
group = "Sex_Trt",
stat = "min",
ndraws = 500,
binwidth = 1
)
# Check the mean predicted vs. observed
pp_check(
fm,
type = "stat_grouped",
group = "Sex_Trt",
stat = "mean",
ndraws = 500,
binwidth = 1
)
Note: in most cases the default test statistic 'mean' is too weak to detect anything of interest.
# Check the maximum predicted vs. observed
pp_check(
fm,
type = "stat_grouped",
group = "Sex_Trt",
stat = "max",
ndraws = 500,
binwidth = 1
)
Summarizing the posterior
Compare the means of Hormone vs. Control separately by sex.
# FIXME
post <- fm |>
spread_rvars(
b_Sex_TrtFemale_Hormone,
b_Sex_TrtFemale_None,
b_Sex_TrtMale_Hormone,
b_Sex_TrtMale_None
) |>
mutate(
`M: Horm. vs. C.` = b_Sex_TrtMale_Hormone - b_Sex_TrtMale_None,
`F: Horm. vs. C.` = b_Sex_TrtFemale_Hormone - b_Sex_TrtFemale_None,
.keep = "none"
)
post |>
pivot_longer(cols = everything()) |>
ggplot(aes(xdist = value, fill = name)) +
stat_slab(alpha = 0.5) +
scale_fill_manual(values = c("darkslateblue", "coral"), name = "Contrast") +
labs(x = "Difference (Hormone - Control)", y = "Density")
median_hdi(post, .width = 0.89)
# A tibble: 1 × 9
`F: Horm. vs. C.` `F: Horm. vs. C..lower` `F: Horm. vs. C..upper`
<dbl> <dbl> <dbl>
1 15.4 10.2 19.8
# ℹ 6 more variables: `M: Horm. vs. C.` <dbl>, `M: Horm. vs. C..lower` <dbl>,
# `M: Horm. vs. C..upper` <dbl>, .width <dbl>, .point <chr>, .interval <chr>
# FIXME
# Here's how you would do this analysis using the regular interactions
# coding with *.
# This generates the prior prediction and converts to the group posteriors
# The priors are really weak, because there is so little data to learn
# from.
PP <- brm(
Calcium ~ Sex * Treatment,
data = BP,
prior = c(
prior(normal(15, 20), class = Intercept),
prior(normal(0, 10), coef = SexMale),
prior(normal(0, 20), coef = TreatmentHormone),
prior(normal(0, 10), coef = SexMale:TreatmentHormone),
prior(normal(0, 15), class = sigma)
),
refresh = 0,
sample_prior = "only"
) |>
spread_draws(
b_Intercept,
b_SexMale,
b_TreatmentHormone,
`b_SexMale:TreatmentHormone`
) |>
mutate(
Female_None = b_Intercept,
Female_Hormone = b_Intercept + b_TreatmentHormone,
Male_None = b_Intercept + b_SexMale,
Male_Hormone = b_Intercept + b_SexMale + `b_SexMale:TreatmentHormone`,
.keep = "none"
)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
PP_long <- pivot_longer(
PP,
cols = everything(),
names_to = "Sex_Trt",
values_to = "Calcium"
) |>
separate(col = Sex_Trt, into = c("Sex", "Treatment"), sep = "_")
# See how the Female-None group has the lowest variance, and the Male-Hormone
# group has the highest variance
PP_long |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium), var_Calcium = var(Calcium))

`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<chr> <chr> <dbl> <dbl>
1 Female Hormone 15.1 538.
2 Female None 15.1 579.
3 Male Hormone 15.1 628.
4 Male None 15.1 575.
# A kind of prior predictive distribution plot
ggplot() +
geom_density(data = PP_long, aes(Calcium)) +
geom_point(
data = BP,
aes(x = Calcium, y = 0),
shape = 21,
fill = "#9ECAE1",
size = 3
) +
facet_grid(Sex ~ Treatment)

# Fit the model
fm <- brm(
Calcium ~ Sex * Treatment,
data = BP,
prior = c(
prior(normal(15, 20), class = Intercept),
prior(normal(0, 10), coef = SexMale),
prior(normal(0, 20), coef = TreatmentHormone),
prior(normal(0, 10), coef = SexMale:TreatmentHormone),
prior(normal(0, 15), class = sigma)
),
refresh = 0,
iter = 5e3
)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
prior_summary(fm)

          prior     class                     coef group resp dpar nlpar lb ub tag  source
         (flat)         b                                                          default
  normal(0, 10)         b                  SexMale                                    user
  normal(0, 10)         b SexMale:TreatmentHormone                                    user
  normal(0, 20)         b         TreatmentHormone                                    user
 normal(15, 20) Intercept                                                             user
  normal(0, 15)     sigma                                               0             user
# These all look fine.
powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Intercept", "b_SexMale")
)

powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_TreatmentHormone", "b_SexMale:TreatmentHormone")
)

powerscale_sensitivity(fm)

Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_Intercept 0.011 0.103 -
b_SexMale 0.014 0.088 -
b_TreatmentHormone 0.022 0.120 -
b_SexMale:TreatmentHormone 0.026 0.094 -
sigma 0.013 0.300 -
Intercept 0.004 0.110 -
summary(fm)

 Family: gaussian
Links: mu = identity
Formula: Calcium ~ Sex * Treatment
Data: BP (Number of observations: 20)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
Intercept 14.99 2.03 10.95 19.02 1.00 5686
SexMale -2.85 2.74 -8.28 2.58 1.00 5147
TreatmentHormone 17.29 2.83 11.53 22.81 1.00 5189
SexMale:TreatmentHormone -1.62 3.78 -9.04 5.83 1.00 4672
Tail_ESS
Intercept 6602
SexMale 5891
TreatmentHormone 5841
SexMale:TreatmentHormone 5870
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 4.65 0.89 3.29 6.80 1.00 6379 6190
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
# Extract the posterior draws and compute the per-group means
post <- fm |>
spread_draws(
b_Intercept,
b_SexMale,
b_TreatmentHormone,
`b_SexMale:TreatmentHormone`
) |>
mutate(
Female_None = b_Intercept,
Female_Hormone = b_Intercept + b_TreatmentHormone,
Male_None = b_Intercept + b_SexMale,
Male_Hormone = b_Intercept + b_SexMale + `b_SexMale:TreatmentHormone`,
.keep = "none"
)
# The Male-Hormone group retains the wider variance in the posterior, even
# though it does not in the observed data.
pivot_longer(
post,
cols = everything(),
names_to = "Sex_Trt",
values_to = "Calcium"
) |>
separate(col = Sex_Trt, into = c("Sex", "Treatment"), sep = "_") |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium), var_Calcium = var(Calcium))

`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<chr> <chr> <dbl> <dbl>
1 Female Hormone 32.3 4.29
2 Female None 15.0 4.14
3 Male Hormone 10.5 11.6
4 Male None 12.1 4.18
# Roughly equal variances in the observed data
BP |>
group_by(Sex, Treatment) |>
summarize(mean_Calcium = mean(Calcium), var_Calcium = var(Calcium))

`summarise()` has grouped output by 'Sex'. You can override using the `.groups`
argument.
# A tibble: 4 × 4
# Groups: Sex [2]
Sex Treatment mean_Calcium var_Calcium
<fct> <fct> <dbl> <dbl>
1 Female None 14.9 17.1
2 Female Hormone 32.5 21.8
3 Male None 12.1 18.0
4 Male Hormone 27.8 18.4
median_hdi(post, .width = 0.89)

# A tibble: 1 × 15
Female_None Female_None.lower Female_None.upper Female_Hormone
<dbl> <dbl> <dbl> <dbl>
1 15.0 11.8 18.2 32.3
# ℹ 11 more variables: Female_Hormone.lower <dbl>, Female_Hormone.upper <dbl>,
# Male_None <dbl>, Male_None.lower <dbl>, Male_None.upper <dbl>,
# Male_Hormone <dbl>, Male_Hormone.lower <dbl>, Male_Hormone.upper <dbl>,
# .width <dbl>, .point <chr>, .interval <chr>
Multiple continuous predictors
Working with multiple continuous predictors also poses some unique challenges (not to mention continuous predictors with interactions). Visualization in particular is not straightforward, because, unless you want a 3D plot, you can't plot three continuous variables (1 outcome + 2 predictors) simultaneously. Options include making separate plots for each predictor, plotting one predictor on the x-axis and coloring points by the other, or choosing specific values of one predictor at which to visualize the relationship with the other. Usually you will do these reciprocally for the two predictors.
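One of these options, coloring points by the second predictor, can be sketched with hypothetical simulated data (this is not the milk data used below; the variable names are made up for illustration):

```r
library(ggplot2)

set.seed(1)
# Hypothetical data: two negatively correlated predictors and one outcome
x1 <- runif(100, 0, 10)
x2 <- 10 - x1 + rnorm(100, sd = 1)
d <- data.frame(x1, x2, y = 2 + 0.5 * x1 - 0.3 * x2 + rnorm(100))

# Put x1 on the x-axis and map x2 to a continuous color scale
ggplot(d, aes(x = x1, y = y, color = x2)) +
  geom_point(size = 3) +
  scale_color_viridis_c()
```

Reading the gradient of point colors against the x-axis gives a rough sense of how the two predictors jointly relate to the outcome, before committing to a model.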
To work through this example, we will use the (apparent) trade-off between fat content and lactose content in mammal milk. We used this example in Quantitative Methods 1 to show how multiple regression is actually working.
Load the data in Milk.xlsx, select the columns kcal.per.g, perc.fat, and perc.lactose, and rename them to Milk_energy, Fat, and Lactose, respectively. We will predict the first from the additive effects of the latter two.
There are some missing values in the data, so drop any rows with NA. These are comparative data for different species of primates, but we will ignore those relationships for this analysis.
# FIXME
#| warning: false
MM <- readxl::read_excel("../Data/Milk.xlsx") |>
select(kcal.per.g, perc.fat, perc.lactose) |>
drop_na() |>
rename(Milk_energy = kcal.per.g, Fat = perc.fat, Lactose = perc.lactose)

Make two plots, one where energy is predicted by fat and the other by lactose.
# FIXME
p1 <- plot_grid(
ggplot(MM, aes(Fat, Milk_energy)) + geom_point(),
ggplot(MM, aes(Lactose, Milk_energy)) + geom_point(),
ncol = 2
)
p1

You will see that they vary inversely: as fat goes up, lactose goes down, because milk composition is constrained to a finite total (100%), so as one component increases another must decrease. The third major component, protein (mostly casein), makes up the remainder; we ignore protein in this analysis.
If you check the correlation between fat and lactose, you will see it’s large (\(r \approx\) -0.94). In a frequentist regression, you might be worried about multicollinearity in this case.
# FIXME
cor(MM$Fat, MM$Lactose)

[1] -0.9416373
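For reference, in a two-predictor regression the variance inflation factor can be computed by hand as VIF = 1 / (1 - r^2), where r is the correlation between the predictors. A quick sketch using the correlation above (the rule-of-thumb cutoffs of 5 or 10 are common conventions, not part of this problem set):

```r
# VIF for two predictors: 1 / (1 - r^2), with r = cor(Fat, Lactose)
r <- -0.9416373
vif <- 1 / (1 - r^2)
round(vif, 1)
#> [1] 8.8
```

A VIF near 9 is well above the usual cutoff of 5, which is why a frequentist analysis would flag multicollinearity here.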
Model specification
\[\begin{align}
\mathrm{Milk\_energy} & \sim Normal(\mu, \sigma) \\
\mu & = b_0 + b_1 \mathrm{Fat} + b_2 \mathrm{Lactose}
\end{align}\]
Prior specification and prior predictive check
# FIXME
PP <- brm(
Milk_energy ~ Fat + Lactose,
data = MM,
prior = c(
prior(normal(0, 0.05), class = b),
prior(normal(0, 5), class = Intercept),
prior(normal(0, 5), class = sigma)
),
sample_prior = "only",
refresh = 0
)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.0 seconds.
Chain 2 finished in 0.0 seconds.
Chain 3 finished in 0.0 seconds.
Chain 4 finished in 0.0 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.0 seconds.
Total execution time: 0.6 seconds.
Warning: 3 of 4000 (0.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
pp_check(PP, ndraws = 100)

pp_check(PP, type = "stat", stat = "median", binwidth = 1)

Using all posterior draws for ppc type 'stat' by default.
Final model specification
\[\begin{align}
\mathrm{Milk\_energy} & \sim Normal(\mu, \sigma) \\
\mu & = b_0 + b_1 \mathrm{Fat} + b_2 \mathrm{Lactose} \\
b_0 & \sim Normal(0, 5) \\
b_1 & \sim Normal(0, 0.05) \\
b_2 & \sim Normal(0, 0.05) \\
\sigma & \sim HalfNormal(0, 5)
\end{align}\]
Sampling
# FIXME
fm <- brm(
Milk_energy ~ Fat + Lactose,
data = MM,
prior = c(
prior(normal(0, 0.05), class = b),
prior(normal(0, 5), class = Intercept),
prior(normal(0, 5), class = sigma)
),
refresh = 0,
iter = 5e3
)

Start sampling
Running MCMC with 4 sequential chains...
Chain 1 finished in 0.1 seconds.
Chain 2 finished in 0.1 seconds.
Chain 3 finished in 0.1 seconds.
Chain 4 finished in 0.1 seconds.
All 4 chains finished successfully.
Mean chain execution time: 0.1 seconds.
Total execution time: 0.7 seconds.
prior_summary(fm)

           prior class    coef group resp dpar nlpar lb ub tag source
normal(0, 0.05) b user
normal(0, 0.05) b Fat (vectorized)
normal(0, 0.05) b Lactose (vectorized)
normal(0, 5) Intercept user
normal(0, 5) sigma 0 user
Diagnostics
# FIXME
mcmc_combo(fm, regex_pars = "^b")

mcmc_rank_overlay(fm, regex_pars = "^b")

powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = "b_Intercept"
)powerscale_plot_dens(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Fat", "b_Lactose")
)

powerscale_plot_quantities(
powerscale_sequence(fm),
quantity = c("mean", "sd"),
variable = c("b_Intercept", "b_Fat", "b_Lactose")
)

powerscale_sensitivity(fm)

Sensitivity based on cjs_dist
Prior selection: all priors
Likelihood selection: all data
variable prior likelihood diagnosis
b_Intercept 0.001 0.096 -
b_Fat 0.001 0.096 -
b_Lactose 0.001 0.093 -
sigma 0.000 0.209 -
Intercept 0.000 0.089 -
Posterior predictive simulation
Use pp_check() to make a density plot of the observed data superimposed on draws from the posterior.
# FIXME
pp_check(fm, ndraws = 100)

To visualize the effect of the two continuous predictors, we'll have to get creative. Here are the steps:
- Make a grid of observations for prediction. Make a sequence of 200 values between 3 and 56 for Fat. Specify only three values for Lactose: 30, 50, and 70. Each value of Fat will be associated with three levels of Lactose.
- Generate the posterior expected values using posterior_epred() and the new data you just created.
- Calculate the median and 89% HDI intervals using mutate() like we did in the lecture slides.
# FIXME
pred_values <- crossing(
Fat = seq(3, 56, length.out = 200),
Lactose = c(30, 50, 70)
)
p_pred <- posterior_epred(fm, newdata = pred_values)
pred_values <- pred_values |>
mutate(
Q50 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.5),
Q5.5 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.055),
Q94.5 = apply(p_pred, MARGIN = 2, FUN = quantile, prob = 0.945),
Lactose = factor(Lactose)
)
pred_values

# A tibble: 600 × 5
Fat Lactose Q50 Q5.5 Q94.5
<dbl> <fct> <dbl> <dbl> <dbl>
1 3 30 0.749 0.540 0.968
2 3 50 0.576 0.446 0.710
3 3 70 0.404 0.343 0.463
4 3.27 30 0.750 0.542 0.968
5 3.27 50 0.577 0.447 0.710
6 3.27 70 0.404 0.345 0.463
7 3.53 30 0.750 0.544 0.967
8 3.53 50 0.577 0.449 0.709
9 3.53 70 0.405 0.346 0.462
10 3.80 30 0.751 0.545 0.966
# ℹ 590 more rows
You should have a 600-row × 5-column tibble, with columns for Fat, Lactose, Q50, Q5.5, and Q94.5.
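As a reminder of why the grid has 600 rows: crossing() returns every combination of its inputs, so the row count is the product of the input lengths. A tiny standalone check (same sequence lengths as above):

```r
library(tidyr)

# 200 Fat values crossed with 3 Lactose values -> 200 * 3 = 600 rows
grid <- crossing(Fat = seq(3, 56, length.out = 200), Lactose = c(30, 50, 70))
nrow(grid)
#> [1] 600
```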
Make a ribbon plot of the 89% interval, add a line for the median, and facet by Lactose in 3 columns. You should be able to see what the model predicts for milk energy as a function of fat at the three levels of lactose.
It will take some staring at this plot to make sense of it. Pay particular attention to the places where the model is pretty sure (narrow bands) or unsure (wide bands).
If you make a composite plot with the pair of scatterplots from the first chunk in this example in one row and this new plot in row 2, it might help to make sense of the output.
# FIXME
p2 <- ggplot() +
geom_ribbon(
data = pred_values,
aes(x = Fat, ymin = Q5.5, ymax = Q94.5, fill = Lactose),
alpha = 0.25
) +
geom_line(data = pred_values, aes(x = Fat, y = Q50, color = Lactose)) +
facet_grid(. ~ Lactose) +
scale_color_viridis_d(option = "D") +
scale_fill_viridis_d(option = "D") +
labs(x = "Fat Percentage", y = "Milk Energy")
plot_grid(p1, p2, nrow = 2)

Summarizing the posterior
Summarize the posterior however you think is appropriate.
# FIXME
summary(fm) |> print(digits = 4)

 Family: gaussian
Links: mu = identity
Formula: Milk_energy ~ Fat + Lactose
Data: MM (Number of observations: 29)
Draws: 4 chains, each with iter = 5000; warmup = 2500; thin = 1;
total post-warmup draws = 10000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 1.0043 0.2213 0.5632 1.4397 1.0008 4825 5017
Fat 0.0020 0.0026 -0.0032 0.0072 1.0007 4970 5295
Lactose -0.0087 0.0027 -0.0140 -0.0034 1.0007 4921 4903
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma 0.0676 0.0101 0.0514 0.0910 1.0000 5184 5029
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Footnotes
Kallioinen, N., T. Paananen, P.-C. Bürkner, and A. Vehtari. 2024. Detecting and diagnosing prior and likelihood sensitivity with power-scaling. Stat. Comput. 34.